#' @include generics.R
#' @include preprocessing.R
#' @importFrom stats loess
#' @importFrom methods slot
#' @importFrom SeuratObject .MARGIN .SparseSlots
#' @importFrom utils txtProgressBar setTxtProgressBar
#'
NULL

# Empty list created at load time.
# NOTE(review): presumably a registry of highly-variable-feature methods; the
# only registration visible in this file is commented out -- confirm it is
# still needed.
hvf.methods <- list()

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Generics
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Functions
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Methods for Seurat-defined generics
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#' @method FindVariableFeatures default
#' @export
#'
FindVariableFeatures.default <- function(
  object,
  method = VST,
  nfeatures = 2000L,
  verbose = TRUE,
  selection.method = NULL,
  ...
) {
  # Run a highly-variable-feature (HVF) method on a matrix-like object.
  #
  # object:           input handed to `method` as `data`
  # method:           an HVF function, a quosure that evaluates to one, or the
  #                   name of one; called as
  #                   method(data =, nselect =, verbose =, ...)
  # nfeatures:        number of features to select (forwarded as `nselect`)
  # verbose:          forwarded to `method`
  # selection.method: accepted so callers may pass it, but never read here;
  #                   defaults to NULL so that forcing it cannot fail
  # ...:              forwarded to `method`
  #
  # Returns the value of `method`, with row names set to rownames(object).
  # Stops if `method` does not resolve to a function.
  if (is_quosure(x = method)) {
    method <- eval(
      expr = quo_get_expr(quo = method),
      envir = quo_get_env(quo = method)
    )
  }
  if (is.character(x = method)) {
    # Only look for functions so a same-named non-function object is skipped
    method <- get(x = method, mode = 'function')
  }
  if (!is.function(x = method)) {
    stop(
      "'method' must be a function for calculating highly variable features",
      call. = FALSE
    )
  }
  hvf.info <- method(
    data = object,
    nselect = nfeatures,
    verbose = verbose,
    ...
  )
  rownames(x = hvf.info) <- rownames(x = object)
  return(hvf.info)
}


#' @importFrom SeuratObject DefaultLayer Features Key Layers
#'
#' @method FindVariableFeatures StdAssay
#' @export
#'
FindVariableFeatures.StdAssay <- function(
  object,
  method = NULL,
  nfeatures = 2000L,
  layer = NULL,
  span = 0.3,
  clip = NULL,
  key = NULL,
  verbose = TRUE,
  selection.method = 'vst',
  ...
) {
  # Find variable features for every layer of the assay matching `layer`.
  #
  # selection.method: 'vst' (default layer 'counts'), 'mean.var.plot'/'mvp' or
  #                   'dispersion'/'disp' (default layer 'data'); any other
  #                   value requires both `method` and `layer` to be supplied
  # method:           custom HVF function; overwritten for the built-in
  #                   selection methods
  # key:              overwritten for the built-in selection methods; for a
  #                   custom method it is derived from `method`
  # nfeatures, span, clip, verbose, ...: forwarded to the HVF function
  #
  # Per layer, the HVF table is stored as feature-level columns named
  # vf_<key>_<layer>_<column>; `VariableFeatures` is then set from `key`.
  # Returns the updated assay.
  if (selection.method == 'vst') {
    layer <- layer%||%'counts'
    method <- VST
    key <- 'vst'
  } else if (selection.method %in% c('mean.var.plot', 'mvp')) {
    layer <- layer%||%'data'
    method <- MVP
    key <- 'mvp'
  } else if (selection.method %in% c('dispersion', 'disp')) {
    layer <- layer%||%'data'
    method <- DISP
    key <- 'disp'
  } else if (is.null(x = method) || is.null(x = layer)){
    stop('Custom functions and layers are both required')
  } else {
    key <- NULL
  }
  layer <- Layers(object = object, search = layer)
  if (is.null(x = key)) {
    # Custom method: build a key from however `method` was supplied
    false <- function(...) {
      return(FALSE)
    }
    key <- if (tryCatch(expr = is_quosure(x = method), error = false)) {
      method
    } else if (is.function(x = method)) {
      substitute(expr = method)
    } else if (is.call(x = enquo(arg = method))) {
      enquo(arg = method)
    } else if (is.character(x = method)) {
      method
    } else {
      parse(text = method)
    }
    key <- .Abbrv(x = as_name(x = key))
  }
  # Each warning is emitted at most once across all layers
  warn.var <- warn.rank <- TRUE
  for (i in seq_along(along.with = layer)) {
    if (isTRUE(x = verbose)) {
      message("Finding variable features for layer ", layer[i])
    }
    data <- LayerData(object = object, layer = layer[i], fast = TRUE)
    hvf.function <- if (inherits(x = data, what = 'V3Matrix')) {
      FindVariableFeatures.default
    } else {
      FindVariableFeatures
    }
    hvf.info <- hvf.function(
      object = data,
      method = method,
      nfeatures = nfeatures,
      span = span,
      clip = clip,
      verbose = verbose,
      ...
    )
    hvf.cols <- colnames(x = hvf.info)
    has.variable <- 'variable' %in% hvf.cols &&
      is.logical(x = hvf.info$variable)
    if (!has.variable) {
      if (warn.var) {
        warning(
          "No variable feature indication in HVF info for method ",
          key,
          ", `VariableFeatures` will not work",
          call. = FALSE,
          immediate. = TRUE
        )
        warn.var <- FALSE
      }
    } else if (warn.rank && !'rank' %in% hvf.cols) {
      # Checked whenever `variable` is usable, so that this warning can
      # actually fire for methods that mark features without ranking them
      warning(
        "No variable feature rank in HVF info for method ",
        key,
        ", `VariableFeatures` will return variable features in assay order",
        call. = FALSE,
        immediate. = TRUE
      )
      warn.rank <- FALSE
    }
    colnames(x = hvf.info) <- paste(
      'vf',
      key,
      layer[i],
      colnames(x = hvf.info),
      sep = '_'
    )
    rownames(x = hvf.info) <- Features(x = object, layer = layer[i])
    # Clear previous results before storing the new columns
    object[["var.features"]] <- NULL
    object[["var.features.rank"]] <- NULL
    object[[names(x = hvf.info)]] <- NULL
    object[[names(x = hvf.info)]] <- hvf.info
  }
  VariableFeatures(object) <- VariableFeatures(object, nfeatures=nfeatures,method = key)
  return(object)
}

#' @param layer Layer in the Assay5 to pull data from
#' @param features If provided, only compute on given features. Otherwise,
#' compute for all features.
#' @param nfeatures Number of features to mark as the top spatially variable.
#'
#' @method FindSpatiallyVariableFeatures StdAssay
#' @rdname FindSpatiallyVariableFeatures
#' @concept preprocessing
#' @concept spatial
#' @export
#'
FindSpatiallyVariableFeatures.StdAssay <- function(
  object,
  layer = "scale.data",
  spatial.location,
  selection.method = c('markvariogram', 'moransi'),
  features = NULL,
  r.metric = 5,
  x.cuts = NULL,
  y.cuts = NULL,
  nfeatures = 2000L,
  verbose = TRUE,
  ...
) {
  # Resolve the choice to a single string; the default is a length-2 vector
  # and cannot be used directly in `if ()` / `&&`
  selection.method <- match.arg(arg = selection.method)
  features <- features %||% rownames(x = object)
  # Mark variogram results cached in Misc are reused: skip features that
  # were already computed
  if (selection.method == "markvariogram" && "markvariogram" %in% names(x = Misc(object = object))) {
    features.computed <- names(x = Misc(object = object, slot = "markvariogram"))
    features <- features[! features %in% features.computed]
  }
  data <- GetAssayData(object = object, layer = layer)
  # drop = FALSE keeps a feature x cell matrix even for a single feature
  data <- as.matrix(x = data[features, , drop = FALSE])
  data <- data[RowVar(x = data) > 0, , drop = FALSE]
  if (nrow(x = data) != 0) {
    svf.info <- FindSpatiallyVariableFeatures(
      object = data,
      spatial.location = spatial.location,
      selection.method = selection.method,
      r.metric = r.metric,
      x.cuts = x.cuts,
      y.cuts = y.cuts,
      verbose = verbose,
      ...
    )
  } else {
    svf.info <- c()
  }
  if (selection.method == "markvariogram") {
    # Merge new results with the cached ones and store the union back
    if ("markvariogram" %in% names(x = Misc(object = object))) {
      svf.info <- c(svf.info, Misc(object = object, slot = "markvariogram"))
    }
    suppressWarnings(expr = Misc(object = object, slot = "markvariogram") <- svf.info)
    svf.info <- ComputeRMetric(mv = svf.info, r.metric)
    svf.info <- svf.info[order(svf.info[, 1]), , drop = FALSE]
  }
  if (selection.method == "moransi") {
    colnames(x = svf.info) <- paste0("MoransI_", colnames(x = svf.info))
    svf.info <- svf.info[order(svf.info[, 2], -abs(svf.info[, 1])), , drop = FALSE]
  }
  var.name <- paste0(selection.method, ".spatially.variable")
  var.name.rank <- paste0(var.name, ".rank")
  # Rows are already sorted best-first; flag the top `nfeatures` of them
  n.top <- min(nrow(x = svf.info), nfeatures)
  svf.info[[var.name]] <- FALSE
  svf.info[[var.name]][seq_len(length.out = n.top)] <- TRUE
  svf.info[[var.name.rank]] <- seq_len(length.out = nrow(x = svf.info))
  object[names(x = svf.info)] <- svf.info
  return(object)
}


#' @rdname LogNormalize
#' @method LogNormalize default
#'
#' @param margin Margin to normalize over
#' @importFrom SeuratObject .CheckFmargin
#'
#' @export
#'
LogNormalize.default <- function(
  data,
  scale.factor = 1e4,
  margin = 2L,
  verbose = TRUE,
  ...
) {
  # Log-normalize `data` one slice at a time along `margin`
  # (rows when margin is 1, columns otherwise):
  #   log1p(x / sum(x) * scale.factor)
  # Slices that contain only zeros are left untouched; dividing by their
  # zero total would otherwise turn them into NaN.
  # Returns `data` with the normalized values written in place.
  margin <- .CheckFmargin(fmargin = margin)
  ncells <- dim(x = data)[margin]
  show.progress <- isTRUE(x = verbose)
  if (show.progress) {
    pb <- txtProgressBar(file = stderr(), style = 3)
    # Close the bar even if normalization fails part-way through
    on.exit(expr = close(con = pb), add = TRUE)
  }
  for (i in seq_len(length.out = ncells)) {
    x <- if (margin == 1L) {
      data[i, ]
    } else {
      data[, i]
    }
    all.zero <- !anyNA(x = x) && all(x == 0)
    if (!all.zero) {
      xnorm <- log1p(x = x / sum(x) * scale.factor)
      if (margin == 1L) {
        data[i, ] <- xnorm
      } else {
        data[, i] <- xnorm
      }
    }
    if (show.progress) {
      setTxtProgressBar(pb = pb, value = i / ncells)
    }
  }
  return(data)
}

#' @method LogNormalize IterableMatrix
#' @export
#'
LogNormalize.IterableMatrix <- function(
    data,
    scale.factor = 1e4,
    margin = 2L,
    verbose = TRUE,
    ...
) {
  # Log-normalize a BPCells matrix: divide every column by its column sum,
  # multiply by `scale.factor`, then apply log1p.
  # `margin`, `verbose` and `...` are accepted but not used.
  column.totals <- colSums(data)
  # Transpose so the division runs along columns, then transpose back
  proportions <- BPCells::t(BPCells::t(data) / column.totals)
  return(log1p(proportions * scale.factor))
}

#' @importFrom SeuratObject IsSparse
#'
#' @method NormalizeData default
#' @export
#'
NormalizeData.default <- function(
  object,
  normalization.method = c('LogNormalize', 'CLR', 'RC'),
  scale.factor = 1e4,
  cmargin = 2L,
  margin = 1L,
  verbose = TRUE,
  ...
) {
  # Normalize a matrix with one of three methods.
  #
  # normalization.method: 'LogNormalize', 'CLR' or 'RC'; only the first
  #                       element is considered
  # scale.factor:         used by 'LogNormalize' and 'RC'
  # cmargin:              margin handed to LogNormalize
  # margin:               margin handed to CustomNormalize for 'CLR'
  # ...:                  forwarded to LogNormalize only
  #
  # Returns the normalized matrix.
  normalization.method <- normalization.method[1L]
  normalization.method <- match.arg(arg = normalization.method)
  # TODO: enable parallelization via future
  dense.or.csc <- c('dgCMatrix', 'matrix')
  if (normalization.method == 'LogNormalize') {
    # The sparse routine applies only when the compressed margin of the
    # matrix is the margin being normalized
    sparse.path <- IsSparse(x = object) && .MARGIN(object = object) == cmargin
    normalized <- if (sparse.path) {
      .SparseNormalize(
        data = object,
        scale.factor = scale.factor,
        verbose = verbose
      )
    } else {
      LogNormalize(
        data = object,
        scale.factor = scale.factor,
        margin = cmargin,
        verbose = verbose,
        ...
      )
    }
  } else if (normalization.method == 'CLR') {
    if (inherits(x = object, what = 'dgTMatrix')) {
      warning('Convert input dgTMatrix into dgCMatrix')
      object <- as(object = object, Class = 'dgCMatrix')
    }
    if (!inherits(x = object, what = dense.or.csc)) {
      stop('CLR normalization is only supported for dense and dgCMatrix')
    }
    clr <- function(x) {
      # Sum of log1p over positive entries, divided by the full length
      log.mean <- sum(log1p(x = x[x > 0]), na.rm = TRUE) / length(x = x)
      return(log1p(x = x / exp(x = log.mean)))
    }
    normalized <- CustomNormalize(
      data = object,
      custom_function = clr,
      margin = margin,
      verbose = verbose
    )
  } else {
    # match.arg() leaves 'RC' as the only remaining choice
    if (!inherits(x = object, what = dense.or.csc)) {
      stop('RC normalization is only supported for dense and dgCMatrix')
    }
    normalized <- RelativeCounts(
      data = object,
      scale.factor = scale.factor,
      verbose = verbose
    )
  }
  return(normalized)
}

#' @importFrom SeuratObject Cells DefaultLayer DefaultLayer<- Features
#' LayerData LayerData<-
#'
#' @method NormalizeData StdAssay
#' @export
#'
NormalizeData.StdAssay <- function(
  object,
  normalization.method = 'LogNormalize',
  scale.factor = 1e4,
  margin = 1L,
  layer = 'counts',
  save = 'data',
  verbose = TRUE,
  ...
) {
  # Normalize every layer matching `layer` and store each result in the
  # corresponding entry of `save`. When `save` does not have one name per
  # matched layer, the names are derived by substituting `save` for the
  # search pattern in each layer name and making the results unique.
  # Returns the updated assay.
  pattern <- unique(x = layer)
  sources <- Layers(object = object, search = pattern)
  if (length(x = save) != length(x = sources)) {
    derived <- gsub(pattern = pattern, replacement = save, x = sources)
    save <- make.unique(names = derived)
  }
  chatty <- isTRUE(x = verbose)
  for (idx in seq_along(along.with = sources)) {
    src <- sources[idx]
    if (chatty) {
      message("Normalizing layer: ", src)
    }
    normalized <- NormalizeData(
      object = LayerData(object = object, layer = src, fast = NA),
      normalization.method = normalization.method,
      scale.factor = scale.factor,
      margin = margin,
      verbose = verbose,
      ...
    )
    LayerData(
      object = object,
      layer = save[idx],
      features = Features(x = object, layer = src),
      cells = Cells(x = object, layer = src)
    ) <- normalized
  }
  gc(verbose = FALSE)
  return(object)
}


#' @importFrom SeuratObject StitchMatrix
#'
#' @method ScaleData StdAssay
#' @export
#'
ScaleData.StdAssay <- function(
  object,
  features = NULL,
  layer = 'data',
  vars.to.regress = NULL,
  latent.data = NULL,
  by.layer = FALSE,
  split.by = NULL,
  model.use = 'linear',
  use.umi = FALSE,
  do.scale= TRUE,
  do.center = TRUE,
  scale.max = 10,
  block.size = 1000,
  min.cells.to.block = 3000,
  save = 'scale.data',
  verbose = TRUE,
  ...
) {
  # Scale the data held in the layers matching `layer` and store the result
  # in `save`. With by.layer = TRUE every matched layer is scaled on its own;
  # otherwise the matched layers are stitched into one matrix and scaled
  # together. All scaling options are forwarded unchanged to ScaleData() on
  # the extracted matrix. Returns the updated assay.

  # Any model other than 'linear' forces use.umi to TRUE
  use.umi <- ifelse(test = model.use != 'linear', yes = TRUE, no = use.umi)
  olayer <- layer <- unique(x = layer)
  layer <- Layers(object = object, search = layer)
  if (is.null(layer)) {
    abort(paste0("No layer matching pattern '", olayer, "' found. Please run NormalizeData and retry"))
  }
  if (isTRUE(x = use.umi)) {
    # The matched layers are discarded in favour of "counts"
    layer <- "counts"
    inform(
      message = "'use.umi' is TRUE, please make sure 'layer' specifies raw counts"
    )
  }
  # Default to the variable features; fall back to all features of the
  # selected layer(s) when none are set
  features <- features %||% VariableFeatures(object = object)
  if (!length(x = features)) {
    features <- Features(x = object, layer = layer)
  }
  if (isTRUE(x = by.layer)) {
    # One output name per layer: when `save` does not match in length, derive
    # the names by substituting `save` for the search pattern
    if (length(x = save) != length(x = layer)) {
      save <- make.unique(names = gsub(
        pattern = olayer,
        replacement = save,
        x = layer
      ))
    }
    for (i in seq_along(along.with = layer)) {
      lyr <- layer[i]
      if (isTRUE(x = verbose)) {
        inform(message = paste("Scaling data for layer", sQuote(x = lyr)))
      }
      # `...` is forwarded both to the LayerData<- setter and to ScaleData()
      LayerData(object = object, layer = save[i], ...) <- ScaleData(
        object = LayerData(
          object = object,
          layer = lyr,
          features = features,
          fast = NA
        ),
        features = features,
        vars.to.regress = vars.to.regress,
        latent.data = latent.data,
        split.by = split.by,
        model.use = model.use,
        use.umi = use.umi,
        do.scale = do.scale,
        do.center = do.center,
        scale.max = scale.max,
        block.size = block.size,
        min.cells.to.block = min.cells.to.block,
        verbose = verbose,
        ...
      )
    }
  } else {
    # Several layers: combine them with StitchMatrix(), passing the matching
    # parts of the assay's 'features' and 'cells' slots as row/column maps
    ldata <- if (length(x = layer) > 1L) {
      StitchMatrix(
        x = LayerData(object = object, layer = layer[1L], features = features),
        y = lapply(
          X = layer[2:length(x = layer)],
          FUN = LayerData,
          object = object,
          features = features
        ),
        rowmap = slot(object = object, name = 'features')[features, layer],
        colmap = slot(object = object, name = 'cells')[, layer]
      )
    } else {
      LayerData(object = object, layer = layer, features = features)
    }
    ldata <- ScaleData(
      object = ldata,
      features = features,
      vars.to.regress = vars.to.regress,
      latent.data = latent.data,
      split.by = split.by,
      model.use = model.use,
      use.umi = use.umi,
      do.scale = do.scale,
      do.center = do.center,
      scale.max = scale.max,
      block.size = block.size,
      min.cells.to.block = min.cells.to.block,
      verbose = verbose,
      ...
    )
    # Store under `save`, keyed by the features present in the scaled matrix
    LayerData(object = object, layer = save, features = rownames(ldata)) <- ldata
  }
  return(object)
}

#' @rdname VST
#' @method VST default
#' @export
#'
VST.default <- function(
  data,
  margin = 1L,
  nselect = 2000L,
  span = 0.3,
  clip = NULL,
  ...
) {
  # Fallback for inputs without a dedicated VST method: always signals a
  # "not yet implemented" error; none of the arguments are used.
  .NotYetImplemented()
}

#' @rdname VST
#' @method VST IterableMatrix
#' @importFrom SeuratObject EmptyDF
#' @export
#'
VST.IterableMatrix <- function(
    data,
    margin = 1L,
    nselect = 2000L,
    span = 0.3,
    clip = NULL,
    verbose = TRUE,
    ...
) {
  # Variance-stabilizing feature selection for a BPCells matrix.
  #
  # Steps: per-row mean and variance; loess fit of log10(variance) on
  # log10(mean) over rows with positive variance; values capped per row,
  # standardized with the fitted standard deviation, and the variance of the
  # standardized values used to rank rows. The top `nselect` rows are flagged.
  # `margin`, `verbose` and `...` are accepted but not used.
  #
  # Returns a data frame with columns mean, variance, variance.expected,
  # variance.standardized, variable and rank, one row per row of `data`.
  hvf.info <- EmptyDF(n = nrow(x = data))
  row.stats <- BPCells::matrix_stats(
    matrix = data,
    row_stats = 'variance'
  )$row_stats
  hvf.info$mean <- row.stats['mean', ]
  hvf.info$variance <- row.stats['variance', ]
  # Rows with zero variance keep an expected variance of 0
  hvf.info$variance.expected <- 0L
  varying <- hvf.info$variance > 0
  trend <- loess(
    formula = log10(x = variance) ~ log10(x = mean),
    data = hvf.info[varying, , drop = TRUE],
    span = span
  )
  hvf.info$variance.expected[varying] <- 10 ^ trend$fitted
  expected.sd <- sqrt(x = hvf.info$variance.expected)
  max.standardized <- clip %||% sqrt(x = ncol(x = data))
  centers <- hvf.info$mean
  # Zero means are replaced by 0.1 before centering and capping
  centers[centers == 0] <- 0.1
  row.caps <- max.standardized * expected.sd + centers
  capped <- BPCells::min_by_row(mat = data, vals = row.caps)
  standardized <- (capped - centers) / expected.sd
  hvf.info$variance.standardized <- BPCells::matrix_stats(
    matrix = standardized,
    row_stats = 'variance'
  )$row_stats['variance', ]
  # Flag and rank the top `nselect` rows
  hvf.info$variable <- FALSE
  hvf.info$rank <- NA
  top <- head(
    x = order(hvf.info$variance.standardized, decreasing = TRUE),
    n = nselect
  )
  hvf.info$variable[top] <- TRUE
  hvf.info$rank[top] <- seq_along(along.with = top)
  rownames(x = hvf.info) <- rownames(x = capped)
  return(hvf.info)
}

#' @importFrom Matrix rowMeans
#' @importFrom SeuratObject EmptyDF
#'
#' @rdname VST
#' @method VST dgCMatrix
#' @export
#'
VST.dgCMatrix <- function(
  data,
  margin = 1L,
  nselect = 2000L,
  span = 0.3,
  clip = NULL,
  verbose = TRUE,
  ...
) {
  # Variance-stabilizing feature selection for a dgCMatrix.
  #
  # Row means and variances are computed, log10(variance) is fit on
  # log10(mean) with loess over rows of positive variance, and the variance
  # of standardized values (capped at `clip`, or sqrt(ncol) when NULL) ranks
  # the rows. The top `nselect` rows are flagged.
  # `margin` and `...` are accepted but not used.
  #
  # Returns a data frame with columns mean, variance, variance.expected,
  # variance.standardized, variable and rank.
  hvf.info <- EmptyDF(n = nrow(x = data))
  row.means <- Matrix::rowMeans(x = data)
  hvf.info$mean <- row.means
  hvf.info$variance <- SparseRowVar2(
    mat = data,
    mu = row.means,
    display_progress = verbose
  )
  # Rows with zero variance keep an expected variance of 0
  hvf.info$variance.expected <- 0L
  varying <- hvf.info$variance > 0
  trend <- loess(
    formula = log10(x = variance) ~ log10(x = mean),
    data = hvf.info[varying, , drop = TRUE],
    span = span
  )
  hvf.info$variance.expected[varying] <- 10 ^ trend$fitted
  expected.sd <- sqrt(x = hvf.info$variance.expected)
  max.standardized <- clip %||% sqrt(x = ncol(x = data))
  hvf.info$variance.standardized <- SparseRowVarStd(
    mat = data,
    mu = hvf.info$mean,
    sd = expected.sd,
    vmax = max.standardized,
    display_progress = verbose
  )
  # Flag and rank the top `nselect` rows
  hvf.info$variable <- FALSE
  hvf.info$rank <- NA
  ordering <- order(hvf.info$variance.standardized, decreasing = TRUE)
  top <- head(x = ordering, n = nselect)
  hvf.info$variable[top] <- TRUE
  hvf.info$rank[top] <- seq_along(along.with = top)
  return(hvf.info)
}

#' @rdname VST
#' @method VST matrix
#' @export
#'
VST.matrix <- function(
  data,
  margin = 1L,
  nselect = 2000L,
  span = 0.3,
  clip = NULL,
  ...
) {
  # Dense input: convert to a sparse matrix and dispatch VST() again,
  # forwarding every argument unchanged.
  sparse.data <- as.sparse(x = data)
  hvf.info <- VST(
    data = sparse.data,
    margin = margin,
    nselect = nselect,
    span = span,
    clip = clip,
    ...
  )
  return(hvf.info)
}

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Methods for R-defined generics
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Internal
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


#' Calculate dispersion of features
#'
#' Computes a mean and a dispersion per feature, bins features by their mean,
#' and scales each dispersion by the mean and standard deviation of the
#' dispersions in its bin.
#'
#' @param object Data matrix (dense \code{matrix} or \code{dgCMatrix}); dense
#' input is converted to sparse first
#' @param mean.function Function to calculate mean; called as
#' \code{mean.function(object, verbose)}
#' @param dispersion.function Function to calculate dispersion; called as
#' \code{dispersion.function(object, verbose)}
#' @param num.bin Number of bins to use
#' @param binning.method Method to use for binning. Options are 'equal_width' or 'equal_frequency'
#' @param verbose Display progress
#' @param ... Not used
#'
#' @return A data frame with one row per feature and columns
#' \code{mvp.mean}, \code{mvp.dispersion} and \code{mvp.dispersion.scaled}
#'
#' @keywords internal
#'
CalcDispersion <- function(
  object,
  mean.function = FastExpMean,
  dispersion.function = FastLogVMR,
  num.bin = 20,
  binning.method = "equal_width",
  verbose = TRUE,
  ...
) {
  if (!inherits(x = object, what = c('dgCMatrix', 'matrix'))) {
    stop('mean.var.plot and dispersion methods only \
     support dense and sparse matrix input')
  }
  if (inherits(x = object, what =  'matrix')) {
    object <- as.sparse(x = object)
  }
  feature.names <- rownames(x = object)
  means <- mean.function(object, verbose)
  dispersions <- dispersion.function(object, verbose)
  names(x = dispersions) <- feature.names
  names(x = means) <- feature.names
  # Missing statistics are treated as zero
  dispersions[is.na(x = dispersions)] <- 0
  means[is.na(x = means)] <- 0
  # Either a bin count (equal width) or explicit quantile break points
  bin.breaks <- switch(
    EXPR = binning.method,
    'equal_width' = num.bin,
    'equal_frequency' = c(
      quantile(
        x = means[means > 0],
        probs = seq.int(from = 0, to = 1, length.out = num.bin)
      )
    ),
    stop("Unknown binning method: ", binning.method)
  )
  bins <- cut(x = means, breaks = bin.breaks, include.lowest = TRUE)
  names(x = bins) <- names(x = means)
  bin.index <- as.numeric(x = bins)
  bin.mean <- tapply(X = dispersions, INDEX = bins, FUN = mean)
  bin.sd <- tapply(X = dispersions, INDEX = bins, FUN = sd)
  # Z-score every dispersion against the other features in its bin
  scaled <- (dispersions - bin.mean[bin.index]) / bin.sd[bin.index]
  names(x = scaled) <- names(x = means)
  hvf.info <- data.frame(means, dispersions, scaled)
  rownames(x = hvf.info) <- rownames(x = object)
  colnames(x = hvf.info) <- paste0(
    'mvp.',
    c('mean', 'dispersion', 'dispersion.scaled')
  )
  return(hvf.info)
}


#' @importFrom SeuratObject .CalcN
#'
CalcN <- function(object, ...) {
  # Thin wrapper: hand everything to .CalcN (imported from SeuratObject)
  result <- .CalcN(object, ...)
  return(result)
}

#' @method .CalcN IterableMatrix
#' @export
#'
.CalcN.IterableMatrix <- function(object, ...) {
  # Per-column totals and non-zero counts for a BPCells matrix, both taken
  # from a single matrix_stats() call. `...` is not used.
  column.stats <- BPCells::matrix_stats(
    matrix = object,
    col_stats = 'mean'
  )$col_stats
  # Column total recovered as mean * number of rows, rounded
  totals <- round(column.stats['mean', ] * nrow(object))
  detected <- column.stats['nonzero', ]
  return(list(nCount = totals, nFeature = detected))
}

#' Find variable features based on dispersion
#'
#' Runs \code{CalcDispersion} and flags the \code{nselect} features with the
#' largest unscaled dispersion.
#'
#' @param data Data matrix
#' @param nselect Number of top features to select based on dispersion values
#' @param verbose Display progress
#' @param ... Passed on to \code{CalcDispersion}
#'
#' @return The \code{CalcDispersion} data frame with two extra columns:
#' \code{variable} (logical) and \code{rank} (\code{NA} for features that
#' were not selected)
#'
#' @keywords internal
#'
DISP <- function(
  data,
  nselect = 2000L,
  verbose = TRUE,
  ...
) {
  hvf.info <- CalcDispersion(object = data, verbose = verbose, ...)
  by.dispersion <- order(hvf.info$mvp.dispersion, decreasing = TRUE)
  top <- head(x = by.dispersion, n = nselect)
  hvf.info$variable <- FALSE
  hvf.info$variable[top] <- TRUE
  hvf.info$rank <- NA
  hvf.info$rank[top] <- seq_along(along.with = top)
  return(hvf.info)
}

#' @importFrom SeuratObject .CheckFmargin
#'
.FeatureVar <- function(
  data,
  mu,
  fmargin = 1L,
  standardize = FALSE,
  sd = NULL,
  clip = NULL,
  verbose = TRUE
) {
  # Per-feature variance around supplied means, computed slice by slice.
  #
  # data:        matrix indexed with `[i, ]` (fmargin 1) or `[, i]`
  # mu:          one mean per feature
  # standardize: divide centered values by `sd` and cap them at `clip`
  #              before squaring
  # sd:          one standard deviation per feature (standardizing only);
  #              features with sd 0 are skipped and keep a variance of 0
  # clip:        upper cap; defaults to sqrt(number of cells)
  #
  # Returns a numeric vector with one variance per feature.
  fmargin <- .CheckFmargin(fmargin = fmargin)
  data.dim <- dim(x = data)
  ncells <- data.dim[-fmargin]
  nfeatures <- data.dim[fmargin]
  fvars <- vector(mode = 'numeric', length = nfeatures)
  if (length(x = mu) != nfeatures) {
    stop("Wrong number of feature means provided")
  }
  do.standardize <- isTRUE(x = standardize)
  show.progress <- isTRUE(x = verbose)
  if (do.standardize) {
    clip <- clip %||% sqrt(x = ncells)
    if (length(x = sd) != nfeatures) {
      stop("Wrong number of standard deviations")
    }
  }
  if (show.progress) {
    msg <- 'Calculating feature variances'
    if (do.standardize) {
      msg <- paste(msg, 'of standardized and clipped values')
    }
    message(msg)
    pb <- txtProgressBar(style = 3, file = stderr())
  }
  for (i in seq_len(length.out = nfeatures)) {
    skip <- do.standardize && sd[i] == 0
    if (!skip) {
      centered <- if (fmargin == 1L) {
        data[i, , drop = TRUE] - mu[i]
      } else {
        data[, i, drop = TRUE] - mu[i]
      }
      if (do.standardize) {
        centered <- centered / sd[i]
        centered[centered > clip] <- clip
      }
      fvars[i] <- sum(centered ^ 2) / (ncells - 1L)
    }
    if (show.progress) {
      setTxtProgressBar(pb = pb, value = i / nfeatures)
    }
  }
  if (show.progress) {
    close(con = pb)
  }
  return(fvars)
}

.Mean <- function(data, margin = 1L) {
  # Mean of every slice of `data` along `margin`: rows when margin is 1,
  # columns otherwise. Returns an unnamed numeric vector.
  data.dim <- dim(x = data)
  nobs <- data.dim[-margin]
  slice.mean <- function(i) {
    slice <- if (margin == 1L) {
      data[i, , drop = TRUE]
    } else {
      data[, i, drop = TRUE]
    }
    return(sum(slice) / nobs)
  }
  means <- vapply(
    X = seq_len(length.out = data.dim[margin]),
    FUN = slice.mean,
    FUN.VALUE = numeric(length = 1L)
  )
  return(means)
}

.SparseNormalize <- function(data, scale.factor = 1e4, verbose = TRUE) {
  # Log-normalize a compressed sparse matrix in place, one compressed
  # slice at a time: log1p(x / sum(x) * scale.factor) over the stored
  # entries of the slice. Slices without stored entries are skipped.
  # Returns `data` with its entry slot replaced.
  entryname <- .SparseSlots(x = data, type = 'entries')
  p <- slot(object = data, name = .SparseSlots(x = data, type = 'pointers'))
  # Convert 0-based pointers to 1-based positions
  if (p[1L] == 0) {
    p <- p + 1L
  }
  np <- length(x = p) - 1L
  # Work on a local copy of the entries and write the slot back once
  entries <- slot(object = data, name = entryname)
  show.progress <- isTRUE(x = verbose)
  if (show.progress) {
    pb <- txtProgressBar(style = 3L, file = stderr())
  }
  for (i in seq_len(length.out = np)) {
    n.entries <- p[i + 1L] - p[i]
    # An empty slice has p[i] == p[i + 1]; a from/to sequence would count
    # down and touch the neighbouring slices, so index by length instead
    if (n.entries > 0L) {
      idx <- p[i] - 1L + seq_len(length.out = n.entries)
      xidx <- entries[idx]
      entries[idx] <- log1p(x = xidx / sum(xidx) * scale.factor)
    }
    if (show.progress) {
      setTxtProgressBar(pb = pb, value = i / np)
    }
  }
  slot(object = data, name = entryname) <- entries
  if (show.progress) {
    close(con = pb)
  }
  return(data)
}

#' @param data A sparse matrix
#' @param mu A vector of feature means
#' @param fmargin Feature margin
#' @param standardize Standardize matrix rows prior to calculating variances
#' @param sd If standardizing, a vector of standard deviations to
#' standardize with
#' @param clip Set upper bound for standardized variances; defaults to the
#' square root of the number of cells
#' @param verbose Show progress updates
#'
#' @keywords internal
#' @importFrom SeuratObject .CheckFmargin
#'
#' @noRd
#'
.SparseFeatureVar <- function(
  data,
  mu,
  fmargin = 1L,
  standardize = FALSE,
  sd = NULL,
  clip = NULL,
  verbose = TRUE
) {
  # Per-feature variance around supplied means for a compressed sparse
  # matrix. Stored entries are handled explicitly; the implicit zeros of a
  # feature contribute nzero * mu^2 (divided by sd^2 when standardizing).
  # Returns a numeric vector with one variance per feature; features with
  # sd 0 are skipped when standardizing and keep a variance of 0.
  fmargin <- .CheckFmargin(fmargin = fmargin)
  # Features must lie along the compressed margin; transpose otherwise
  if (fmargin != .MARGIN(object = data)) {
    data <- t(x = data)
    fmargin <- .MARGIN(object = data)
  }
  entryname <- .SparseSlots(x = data, type = 'entries')
  p <- slot(object = data, name = .SparseSlots(x = data, type = 'pointers'))
  # Convert 0-based pointers to 1-based positions
  if (p[1L] == 0) {
    p <- p + 1L
  }
  np <- length(x = p) - 1L
  ncells <- dim(x = data)[-fmargin]
  fvars <- vector(mode = 'numeric', length = np)
  if (length(x = mu) != np) {
    stop("Wrong number of feature means provided")
  }
  do.standardize <- isTRUE(x = standardize)
  show.progress <- isTRUE(x = verbose)
  if (do.standardize) {
    clip <- clip %||% sqrt(x = ncells)
    if (length(x = sd) != np) {
      stop("Wrong number of standard deviations provided")
    }
  }
  if (show.progress) {
    msg <- 'Calculating feature variances'
    if (do.standardize) {
      msg <- paste(msg, 'of standardized and clipped values')
    }
    message(msg)
    pb <- txtProgressBar(style = 3, file = stderr())
  }
  # Read the entry slot once instead of on every iteration
  entries <- slot(object = data, name = entryname)
  for (i in seq_len(length.out = np)) {
    if (do.standardize && sd[i] == 0) {
      if (show.progress) {
        setTxtProgressBar(pb = pb, value = i / np)
      }
      next
    }
    # Index by length: a from/to sequence would count down for a feature
    # with no stored entries and pick up entries of neighbouring features
    n.entries <- p[i + 1L] - p[i]
    idx <- p[i] - 1L + seq_len(length.out = n.entries)
    xidx <- entries[idx] - mu[i]
    nzero <- ncells - length(x = xidx)
    # Contribution of the implicit zeros
    csum <- nzero * if (do.standardize) {
      ((0 - mu[i]) / sd[i]) ^ 2
    } else {
      mu[i] ^ 2
    }
    if (do.standardize) {
      xidx <- xidx / sd[i]
      xidx[xidx > clip] <- clip
    }
    fsum <- sum(xidx ^ 2) + csum
    fvars[i] <- fsum / (ncells - 1L)
    if (show.progress) {
      setTxtProgressBar(pb = pb, value = i / np)
    }
  }
  if (show.progress) {
    close(con = pb)
  }
  return(fvars)
}

#' @importFrom SeuratObject .CheckFmargin
.SparseMean <- function(data, margin = 1L) {
  # Mean along `margin` of a compressed sparse matrix: the sum of the stored
  # entries of each slice divided by the full slice length, so implicit
  # zeros count. Returns one mean per slice; slices without stored entries
  # get a mean of 0.
  margin <- .CheckFmargin(fmargin = margin)
  # Slices must lie along the compressed margin; transpose otherwise
  if (margin != .MARGIN(object = data)) {
    data <- t(x = data)
    margin <- .MARGIN(object = data)
  }
  entryname <- .SparseSlots(x = data, type = 'entries')
  p <- slot(object = data, name = .SparseSlots(x = data, type = 'pointers'))
  # Convert 0-based pointers to 1-based positions
  if (p[1L] == 0) {
    p <- p + 1L
  }
  np <- length(x = p) - 1L
  nobs <- dim(x = data)[-margin]
  # Read the entry slot once instead of on every iteration
  entries <- slot(object = data, name = entryname)
  means <- vector(mode = 'numeric', length = np)
  for (i in seq_len(length.out = np)) {
    # Index by length: a from/to sequence would count down for an empty
    # slice and sum entries that belong to neighbouring slices
    n.entries <- p[i + 1L] - p[i]
    idx <- p[i] - 1L + seq_len(length.out = n.entries)
    means[i] <- sum(entries[idx]) / nobs
  }
  return(means)
}

#' @inheritParams stats::loess
#' @param data A matrix
#' @param fmargin Feature margin
#' @param nselect Number of features to select
#' @param clip After standardization values larger than \code{clip} will be set
#' to \code{clip}; default is \code{NULL} which sets this value to the square
#' root of the number of cells
#'
#' @importFrom Matrix rowMeans
#' @importFrom SeuratObject .CheckFmargin
#'
#' @keywords internal
#'
#' @noRd
#'
.VST <- function(
  data,
  fmargin = 1L,
  nselect = 2000L,
  span = 0.3,
  clip = NULL,
  verbose = TRUE,
  ...
) {
  # Variance-stabilizing feature selection using the R-level variance
  # helpers (.SparseFeatureVar for sparse input, .FeatureVar otherwise).
  # Returns a data frame with columns mean, variance, variance.expected,
  # variance.standardized, variable and rank.
  fmargin <- .CheckFmargin(fmargin = fmargin)
  nfeatures <- dim(x = data)[fmargin]
  # TODO: Support transposed matrices
  # Means are always taken over rows, so features are expected in rows
  var.func <- if (IsSparse(x = data)) {
    .SparseFeatureVar
  } else {
    .FeatureVar
  }
  hvf.info <- SeuratObject::EmptyDF(n = nfeatures)
  hvf.info$mean <- rowMeans(x = data)
  hvf.info$variance <- var.func(
    data = data,
    mu = hvf.info$mean,
    fmargin = fmargin,
    verbose = verbose
  )
  # Features with zero variance keep an expected variance of 0
  hvf.info$variance.expected <- 0L
  not.const <- hvf.info$variance > 0
  fit <- loess(
    formula = log10(x = variance) ~ log10(x = mean),
    data = hvf.info[not.const, , drop = TRUE],
    span = span
  )
  hvf.info$variance.expected[not.const] <- 10 ^ fit$fitted
  # Same margin as the first variance pass
  hvf.info$variance.standardized <- var.func(
    data = data,
    mu = hvf.info$mean,
    fmargin = fmargin,
    standardize = TRUE,
    sd = sqrt(x = hvf.info$variance.expected),
    clip = clip,
    verbose = verbose
  )
  # Flag and rank the top `nselect` features
  hvf.info$variable <- FALSE
  hvf.info$rank <- NA
  vf <- head(
    x = order(hvf.info$variance.standardized, decreasing = TRUE),
    n = nselect
  )
  hvf.info$variable[vf] <- TRUE
  hvf.info$rank[vf] <- seq_along(along.with = vf)
  return(hvf.info)
}

# hvf.methods$vst <- VST

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# S4 Methods
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%




################################################################################
################################# SCTransform ##################################
################################################################################

#' @importFrom SeuratObject Cells as.sparse
#'
#' @method SCTransform IterableMatrix
#' @rdname SCTransform
#' @concept preprocessing
#' @export
SCTransform.IterableMatrix <- function(
    object,
    cell.attr,
    reference.SCT.model = NULL,
    do.correct.umi = TRUE,
    ncells = 5000,
    residual.features = NULL,
    variable.features.n = 3000,
    variable.features.rv.th = 1.3,
    vars.to.regress = NULL,
    latent.data = NULL,
    do.scale = FALSE,
    do.center = TRUE,
    clip.range = c(-sqrt(x = ncol(x = object) / 30), sqrt(x = ncol(x = object) / 30)),
    vst.flavor = 'v2',
    conserve.memory = FALSE,
    return.only.var.genes = TRUE,
    seed.use = 1448145,
    verbose = TRUE,
    ...
) {
  # SCTransform for a BPCells matrix: draw at most `ncells` random columns,
  # convert that subset to a sparse matrix and run SCTransform() on it with
  # every other argument forwarded unchanged.
  # A reference model switches off UMI correction and centering.
  # When UMI correction is off, the sampled counts are returned as
  # `umi_corrected`.
  if (!is.null(x = seed.use)) {
    set.seed(seed = seed.use)
  }
  has.reference <- !is.null(reference.SCT.model)
  if (has.reference) {
    do.correct.umi <- FALSE
    do.center <- FALSE
  }
  n.available <- ncol(x = object)
  sampled_cells <- sample.int(
    n = n.available,
    size = min(ncells, n.available)
  )
  umi <- as.sparse(x = object[, sampled_cells])
  # Keep only the attributes of the sampled cells, in matching order
  cell.attr <- cell.attr[colnames(x = umi), , drop = FALSE]
  vst.out <- SCTransform(
    object = umi,
    cell.attr = cell.attr,
    reference.SCT.model = reference.SCT.model,
    do.correct.umi = do.correct.umi,
    ncells = ncells,
    residual.features = residual.features,
    variable.features.n = variable.features.n,
    variable.features.rv.th = variable.features.rv.th,
    vars.to.regress = vars.to.regress,
    latent.data = latent.data,
    do.scale = do.scale,
    do.center = do.center,
    clip.range = clip.range,
    vst.flavor = vst.flavor,
    conserve.memory = conserve.memory,
    return.only.var.genes = return.only.var.genes,
    seed.use = seed.use,
    verbose = verbose,
    ...
  )
  if (!do.correct.umi) {
    vst.out$umi_corrected <- umi
  }
  return(vst.out)
}


#' @importFrom SeuratObject CreateAssayObject SetAssayData GetAssayData
CreateSCTAssay <- function(vst.out,  do.correct.umi, residual.type, clip.range){
  # Build an SCTAssay from SCTransform output.
  #
  # vst.out:    list read for `umi_corrected` (counts), `variable_features`,
  #             `y` (scale.data), `residual_type` and `sct.method`
  # clip.range: stored in the model arguments as `sct.clip.range`
  # `do.correct.umi` is never read; `residual.type` is overwritten from
  # `vst.out` and not used afterwards.
  #
  # Returns the assay coerced to class "SCTAssay", with the remaining
  # `vst.out` (minus `y`) stored in Misc under 'vst.out'.
  residual.type <- vst.out[['residual_type']] %||% 'pearson'
  sct.method <- vst.out[['sct.method']]
  assay.out <- CreateAssayObject(counts = vst.out$umi_corrected)
  VariableFeatures(object = assay.out) <- vst.out$variable_features

  # data layer: log1p of the counts just stored
  log.counts <- log1p(x = GetAssayData(object = assay.out, slot = 'counts'))
  assay.out <- SetAssayData(
    object = assay.out,
    slot = 'data',
    new.data = log.counts
  )

  # scale.data layer: the `y` element, which is then dropped from the model
  assay.out <- SetAssayData(
    object = assay.out,
    slot = 'scale.data',
    new.data = vst.out$y
  )
  vst.out$y <- NULL

  # Record how the model was produced before attaching it to the assay
  vst.out$arguments$sct.clip.range <- clip.range
  vst.out$arguments$sct.method <- sct.method
  Misc(object = assay.out, slot = 'vst.out') <- vst.out

  sct.assay <- as(object = assay.out, Class = "SCTAssay")
  return(sct.assay)
}

# SCTransform for an assay that stores its counts in one or more layers.
#
# Every layer matched by `layer` is processed on its own:
#   1. a model is learned by calling SCTransform.default() (V3Matrix data) or
#      the SCTransform() generic (anything else);
#   2. Pearson residuals and corrected counts are computed for all cells of
#      the layer in blocks of `ncells` cells.
# When more than one layer was processed, residuals are added per layer for
# the features in the union of variable features that the layer has model
# parameters for, and the per-layer assays are merged.
#
#   object               assay holding the count layers
#   layer                layer name (prefix) handed to Layers(search = )
#   cell.attr            optional per-cell attributes, subset by cell name for
#                        each layer; when no reference model is given and the
#                        "umi"/"log_umi" columns are not both present they are
#                        computed from the layer counts
#   reference.SCT.model  optional model; when supplied, UMI correction and
#                        centering are switched off and this model is used to
#                        compute the residuals
#   ncells               forwarded to the model fit; also the block size used
#                        when residuals are computed
#   seed.use             RNG seed, set once up front and again before each
#                        layer; NULL leaves the RNG untouched
#   verbose              print progress messages
#   remaining arguments and `...` are forwarded to the per-layer call and, for
#   `vars.to.regress`, `latent.data`, `do.scale`, `do.center`, to ScaleData()
#
# Returns the (merged) assay; its SCTModel.list is named after the datasets.
#' @importFrom SeuratObject Cells DefaultLayer DefaultLayer<- Features
#' LayerData LayerData<- as.sparse
#'
#' @method SCTransform StdAssay
#' @export
#'
SCTransform.StdAssay <- function(
  object,
  layer = 'counts',
  cell.attr = NULL,
  reference.SCT.model = NULL,
  do.correct.umi = TRUE,
  ncells = 5000,
  residual.features = NULL,
  variable.features.n = 3000,
  variable.features.rv.th = 1.3,
  vars.to.regress = NULL,
  latent.data = NULL,
  do.scale = FALSE,
  do.center = TRUE,
  clip.range = c(-sqrt(x = ncol(x = object) / 30), sqrt(x = ncol(x = object) / 30)),
  vst.flavor = 'v2',
  conserve.memory = FALSE,
  return.only.var.genes = TRUE,
  seed.use = 1448145,
  verbose = TRUE,
  ...) {
  if (!is.null(x = seed.use)) {
    set.seed(seed = seed.use)
  }
  if (!is.null(reference.SCT.model)){
    do.correct.umi <- FALSE
    do.center <- FALSE
  }
  layer <- unique(x = layer)
  layers <- Layers(object = object, search = layer)
  if (length(x = layers) == 0) {
    # fail with a clear message instead of an obscure subscript error later on
    stop(
      "No layers matching '", paste(layer, collapse = "', '"),
      "' found in the assay",
      call. = FALSE
    )
  }
  dataset.names <- gsub(pattern = paste0(layer, "."), replacement = "", x = layers)
  # one SCT assay per dataset (layer)
  sct.assay.list <- list()
  # min_variance actually used for each dataset; needed again when residuals
  # for additional features are computed after the loop
  min.var.list <- list()

  for (dataset.index in seq_along(along.with = layers)) {
    l <- layers[dataset.index]
    if (isTRUE(x = verbose)) {
      message("Running SCTransform on layer: ", l)
    }
    all_cells <-  Cells(x = object, layer = l)
    all_features <- Features(x = object, layer = l)
    layer.data <- LayerData(
      object = object,
      layer = l,
      features = all_features,
      cells = all_cells
    )
    local.reference.SCT.model <- NULL
    # re-seed per layer, but never call set.seed(NULL): that would re-initialise
    # the RNG and throw away a seed set by the caller
    if (!is.null(x = seed.use)) {
      set.seed(seed = seed.use)
    }
    sct.function <- if (inherits(x = layer.data, what = 'V3Matrix')) {
      SCTransform.default
    } else {
      SCTransform
    }
    if (is.null(x = cell.attr) && is.null(x = reference.SCT.model)){
      calcn <- CalcN(object = layer.data)
      cell.attr.layer <-  data.frame(umi = calcn$nCount,
                               log_umi = log10(x = calcn$nCount))
      rownames(cell.attr.layer) <- colnames(x = layer.data)
    } else {
      cell.attr.layer <- cell.attr[colnames(x = layer.data),, drop=FALSE]
    }
    # Fill in umi/log_umi only when they are not both present as columns.
    # (Testing `"umi" %in% cell.attr.layer` compared against column contents,
    # so supplied values were always overwritten.)
    has.umi.attrs <- all(c("umi", "log_umi") %in% colnames(x = cell.attr.layer))
    if (is.null(x = reference.SCT.model) && !has.umi.attrs){
      calcn <- CalcN(object = layer.data)
      cell.attr.tmp <-  data.frame(umi = calcn$nCount)
      rownames(cell.attr.tmp) <- colnames(x = layer.data)
      cell.attr.layer$umi <- NA
      cell.attr.layer$log_umi <- NA
      cell.attr.layer[rownames(cell.attr.tmp), "umi"] <- cell.attr.tmp$umi
      cell.attr.layer[rownames(cell.attr.tmp), "log_umi"] <- log10(x = cell.attr.tmp$umi)
    }

    # Step 1: Learn model
    vst.out <- sct.function(object = layer.data,
                            do.correct.umi = TRUE,
                            cell.attr = cell.attr.layer,
                            reference.SCT.model = reference.SCT.model,
                            ncells = ncells,
                            residual.features = residual.features,
                            variable.features.n = variable.features.n,
                            variable.features.rv.th = variable.features.rv.th,
                            vars.to.regress = vars.to.regress,
                            latent.data = latent.data,
                            do.scale = do.scale,
                            do.center = do.center,
                            clip.range = clip.range,
                            vst.flavor = vst.flavor,
                            conserve.memory = conserve.memory,
                            return.only.var.genes = return.only.var.genes,
                            seed.use = seed.use,
                            verbose = verbose,
                            ...)
    residual.type <- vst.out[['residual_type']] %||% 'pearson'
    assay.out <- CreateSCTAssay(vst.out = vst.out, do.correct.umi = do.correct.umi, residual.type = residual.type,
                                clip.range = clip.range)

    # If there is no reference model, use the model learned on subset of cells to calculate residuals
    # by setting the learned model as the reference model (local.reference.SCT.model)
    if (is.null(x = reference.SCT.model)) {
      local.reference.SCT.model <- assay.out@SCTModel.list[[1]]
    } else {
      local.reference.SCT.model <- reference.SCT.model
    }
    variable.features <- VariableFeatures(assay.out)

    # once we have the model, just calculate residuals for all cells
    # local.reference.SCT.model set to reference.model if it is non null
    vst_out.reference <- SCTModel_to_vst(SCTModel = local.reference.SCT.model)
    vst_out.reference$gene_attr <- local.reference.SCT.model@feature.attributes
    min_var <- vst_out.reference$arguments$min_variance
    if (min_var == "umi_median"){
      # (median of the non-zero counts of up to `ncells` sampled cells / 5)^2
      counts.x <- as.sparse(x = layer.data[, sample.int(n = ncol(x = layer.data), size =  min(ncells, ncol(x = layer.data)) )])
      min_var <- (median(counts.x@x)/5)^2
    }
    min.var.list[[dataset.names[dataset.index]]] <- min_var

    # Step 2: Use learned model to calculate residuals in chunks
    cells.vector <- seq_len(length.out = ncol(x = layer.data))
    cells.grid <- split(x = cells.vector, f = ceiling(x = seq_along(along.with = cells.vector)/ncells))

    if (length(x = cells.grid) <= 1){
      # Single block: the assay from step 1 already covers every cell
      sct.assay.list[[dataset.names[dataset.index]]] <- assay.out
    } else {
      residuals <- list()
      corrected_counts <- list()
      cell_attrs <- list()
      # with return.only.var.genes residuals are limited to variable features
      features.chunk <- if (return.only.var.genes) variable.features else all_features
      # iterate over chunks to get residuals
      for (i in seq_len(length.out = length(x = cells.grid))) {
        vp <- cells.grid[[i]]
        if (verbose){
          message("Getting residuals for block ", i, "(of ", length(cells.grid), ") for ", dataset.names[[dataset.index]], " dataset")
        }
        counts.vp <- as.sparse(x = layer.data[, vp, drop=FALSE])
        cell.attr.object <- cell.attr.layer[colnames(x = counts.vp),, drop=FALSE]
        vst_out <- vst_out.reference

        vst_out$cell_attr <- cell.attr.object
        vst_out$gene_attr <- vst_out$gene_attr[variable.features,,drop=FALSE]
        new_residual <- get_residuals(
          vst_out = vst_out,
          umi = counts.vp[features.chunk,,drop=FALSE],
          residual_type = "pearson",
          min_variance = min_var,
          res_clip_range = clip.range,
          verbosity =  FALSE
        )
        vst_out$y <- new_residual
        corrected_counts[[i]] <- correct_counts(
          x = vst_out,
          umi = counts.vp[all_features,,drop=FALSE],
          verbosity = FALSE# as.numeric(x = verbose) * 2
        )
        residuals[[i]] <- new_residual
        cell_attrs[[i]] <- cell.attr.object
      }
      new.residuals <- Reduce(cbind, residuals)
      corrected_counts <- Reduce(cbind, corrected_counts)
      cell_attrs <- Reduce(rbind, cell_attrs)
      vst_out.reference$cell_attr <- cell_attrs[colnames(new.residuals),,drop=FALSE]
      SCTModel.list <- PrepVSTResults(vst.res = vst_out.reference, cell.names = all_cells)
      SCTModel.list <- list(model1 = SCTModel.list)
      # scale data here as do.center and do.scale are set to FALSE inside
      new.residuals <- ScaleData(
        new.residuals,
        features = NULL,
        vars.to.regress = vars.to.regress,
        latent.data = latent.data,
        model.use = 'linear',
        use.umi = FALSE,
        do.scale = do.scale,
        do.center = do.center,
        scale.max = Inf,
        block.size = 750,
        min.cells.to.block = 3000,
        verbose = verbose
      )
      assay.out <- CreateSCTAssayObject(counts = corrected_counts, scale.data = new.residuals,  SCTModel.list = SCTModel.list)
      assay.out$data <- log1p(x = corrected_counts)
      VariableFeatures(assay.out) <- variable.features
      # one assay per dataset
      if (verbose){
        message("Finished calculating residuals for ", dataset.names[dataset.index])
      }
      sct.assay.list[[dataset.names[dataset.index]]] <- assay.out
    }
  }
  # Return a single assay by merging everything
  if (length(x = sct.assay.list) == 1){
    merged.assay <- sct.assay.list[[1]]
  } else {
    vf.list <- lapply(X  = sct.assay.list, FUN = function(object.i) VariableFeatures(object = object.i))
    # union of the variable features found in the individual datasets
    var.features <- Reduce(f = union, x = vf.list)
    # calculate residuals for union of features
    for (layer.name in names(x = sct.assay.list)){
      vst_out <- SCTModel_to_vst(SCTModel = slot(object = sct.assay.list[[layer.name]], name = "SCTModel.list")[[1]])
      all_cells <-  Cells(x = object, layer = paste0(layer, ".", layer.name))
      # features in the union that this dataset has model parameters for but
      # that are not among its own variable features yet
      variable.features.target <- intersect(x = rownames(x = vst_out$model_pars_fit), y = var.features)
      variable.features.target <- setdiff(x = variable.features.target, y = VariableFeatures(sct.assay.list[[layer.name]]))
      if (length(x = variable.features.target )<1){
        next
      }
      layer.counts.tmp <- LayerData(
        object = object,
        layer = paste0(layer, ".", layer.name),
        cells = all_cells
      )
      layer.counts.tmp <- as.sparse(x = layer.counts.tmp)
      vst_out$cell_attr <- vst_out$cell_attr[, c("log_umi"), drop=FALSE]
      vst_out$model_pars_fit <- vst_out$model_pars_fit[variable.features.target,,drop=FALSE]
      new_residual <- GetResidualsChunked(
        vst_out = vst_out,
        layer.counts = layer.counts.tmp,
        residual_type = "pearson",
        # use this dataset's own min_variance, not whatever the last
        # iteration of the loop above happened to leave behind
        min_variance = min.var.list[[layer.name]],
        res_clip_range = clip.range,
        verbose = FALSE
      )
      old_residual <- GetAssayData(object = sct.assay.list[[layer.name]], slot = 'scale.data')
      merged_residual <- rbind(old_residual, new_residual)
      merged_residual <- ScaleData(
        merged_residual,
        features = NULL,
        vars.to.regress = vars.to.regress,
        latent.data = latent.data,
        model.use = 'linear',
        use.umi = FALSE,
        do.scale = do.scale,
        do.center = do.center,
        scale.max = Inf,
        block.size = 750,
        min.cells.to.block = 3000,
        verbose = verbose
      )
      sct.assay.list[[layer.name]] <- SetAssayData(object = sct.assay.list[[layer.name]], slot = 'scale.data', new.data = merged_residual)
      VariableFeatures(sct.assay.list[[layer.name]]) <- rownames(x = merged_residual)
    }
    merged.assay <- merge(x = sct.assay.list[[1]], y = sct.assay.list[2:length(sct.assay.list)])
    VariableFeatures(object = merged.assay) <- VariableFeatures(object = merged.assay, use.var.features = FALSE, nfeatures = variable.features.n)
  }
  # set the names of SCTmodels to be layer names
  models <- slot(object = merged.assay, name="SCTModel.list")
  names(models) <- names(x = sct.assay.list)
  slot(object = merged.assay, name="SCTModel.list") <- models
  gc(verbose = FALSE)
  return(merged.assay)
}

#' Calculate pearson residuals of features not in the scale.data
#'
#' This function calls sctransform::get_residuals.
#'
#' @param object A seurat object
#' @param features Name of features to add into the scale.data
#' @param assay Name of the assay of the seurat object generated by SCTransform
#' @param layer Name (prefix) of the layer to pull counts from
#' @param umi.assay Name of the assay of the seurat object containing UMI matrix
#' and the default is RNA
#' @param clip.range Numeric of length two specifying the min and max values the
#' Pearson residual will be clipped to
#' @param reference.SCT.model If a reference SCT model should be used
#' for calculating the residuals. When not NULL, it is used instead of the
#' models stored in the assay. Must be a single model fitted with the default
#' regression formula \code{y ~ log_umi}.
#' @param replace.value Recalculate residuals for all features, even if they are
#' already present. Useful if you want to change the clip.range.
#' @param na.rm For features where there is no feature model stored, return NA
#' for residual value in scale.data when na.rm = FALSE. When na.rm is TRUE, only
#' return residuals for features with a model stored for all cells.
#' @param verbose Whether to print messages and progress bars
#'
#' @return A matrix of Pearson residuals with the requested features as rows
#' and the cells of the matched \code{umi.assay} layers as columns (a single
#' feature comes back as a vector because the final subset drops dimensions).
#' The unmodified \code{object} is returned when the assay holds no SCT model
#' or, with \code{na.rm = TRUE}, when no feature is present in every model;
#' \code{NULL} is returned when none of the requested features is left.
#'
#' @importFrom sctransform get_residuals
#' @importFrom matrixStats rowAnyNAs
#'
#' @export
#' @concept preprocessing
#'
#' @seealso \code{\link[sctransform]{get_residuals}}
FetchResiduals <- function(
  object,
  features,
  assay = NULL,
  umi.assay = "RNA",
  layer = "counts",
  clip.range = NULL,
  reference.SCT.model = NULL,
  replace.value = FALSE,
  na.rm = TRUE,
  verbose = TRUE) {
  assay <- assay %||% DefaultAssay(object = object)
  if (IsSCT(assay = object[[assay]])) {
    object[[assay]] <- as(object[[assay]], "SCTAssay")
  }
  if (!inherits(x = object[[assay]], what = "SCTAssay")) {
    stop(assay, " assay was not generated by SCTransform")
  }
  sct.models <- levels(x = object[[assay]])
  if (length(sct.models)==1){
    sct.models <- list(sct.models)
  }
  if (length(x = sct.models) == 0) {
    warning("SCT model not present in assay", call. = FALSE, immediate. = TRUE)
    return(object)
  }
  # features each model has attributes for (looked up once, used twice below)
  model.features <- lapply(X = sct.models, FUN = function(x) {
    rownames(x = SCTResults(object = object[[assay]], slot = "feature.attributes", model = x))
  })
  possible.features <- Reduce(f = union, x = model.features)
  bad.features <- setdiff(x = features, y = possible.features)
  if (length(x = bad.features) > 0) {
    warning("The following requested features are not present in any models: ",
            paste(bad.features, collapse = ", "),
            call. = FALSE
    )
    features <- intersect(x = features, y = possible.features)
  }
  features.orig <- features
  if (na.rm) {
    # only compute residuals when feature model info is present in all
    features <- names(x = which(
      x = table(unlist(x = model.features)) == length(x = sct.models)
    ))
    if (length(x = features) == 0) {
      return(object)
    }
  }

  features <- intersect(x = features.orig, y = features)
  if (length(features) < 1){
    warning("The following requested features are not present in all the models: ",
            paste(features.orig, collapse = ", "),
            call. = FALSE
    )
    return(NULL)
  }

  # Get all (count) layers
  layers <- Layers(object = object[[umi.assay]], search = layer)

  # iterate over layer running sct model for each of the object names
  new.residuals <- list()
  total_cells <- 0
  all_cells <- c()
  if (!is.null(x = reference.SCT.model)) {
    # NOTE(review): the converted list is what gets passed on below, and
    # FetchResidualSCTModel() calls SCTModel_to_vst() on it again - confirm
    # that helper accepts an already converted model.
    if (inherits(x = reference.SCT.model, what = "SCTModel")) {
      reference.SCT.model <- SCTModel_to_vst(SCTModel = reference.SCT.model)
    }
    # scalar condition: `&&` also avoids touching `[[1]]` on a non-list
    if (is.list(x = reference.SCT.model) && inherits(x = reference.SCT.model[[1]], what = "SCTModel")) {
      stop("reference.SCT.model must be one SCTModel rather than a list of SCTModel")
    }
    if (reference.SCT.model$model_str != "y ~ log_umi") {
      stop("reference.SCT.model must be derived using default SCT regression formula, `y ~ log_umi`")
    }
  }
  for (i in seq_along(along.with = layers)) {
    l <- layers[i]
    # the i-th layer is paired with the i-th model
    sct_model <- sct.models[[i]]
    # these cells belong to this layer
    layer_cells <- Cells(x = object[[umi.assay]], layer = l)
    all_cells <- c(all_cells, layer_cells)
    total_cells <- total_cells + length(layer_cells)
    # calculate residual using this model and these cells
    new.residuals[[i]] <- FetchResidualSCTModel(
      object = object,
      umi.assay = umi.assay,
      assay = assay,
      layer = l,
      layer.cells = layer_cells,
      SCTModel = sct_model,
      reference.SCT.model = reference.SCT.model,
      new_features = features,
      replace.value = replace.value,
      clip.range = clip.range,
      verbose = verbose
    )
  }

  # start from an all-NA matrix covering old and new features, then fill in
  # what is already stored and what was just computed
  existing.data <- GetAssayData(object = object, slot = "scale.data", assay = assay)
  all.features <- union(x = rownames(x = existing.data), y = features)
  new.scale <- matrix(
    data = NA,
    nrow = length(x = all.features),
    ncol = total_cells,
    dimnames = list(all.features, all_cells)
  )
  common_cells <- intersect(colnames(new.scale), colnames(existing.data))
  if (nrow(x = existing.data) > 0) {
    new.scale[rownames(x = existing.data), common_cells] <- existing.data[, common_cells, drop = FALSE]
  }
  if (length(x = new.residuals) == 1 && is.list(x = new.residuals)) {
    new.residuals <- new.residuals[[1]]
  } else {
    new.residuals <- Reduce(cbind, new.residuals)
  }
  new.scale[rownames(x = new.residuals), colnames(x = new.residuals)] <- new.residuals

  if (na.rm) {
    # drop = FALSE keeps a matrix when a single row survives; otherwise the
    # subset below fails with "incorrect number of dimensions"
    new.scale <- new.scale[!rowAnyNAs(x = new.scale), , drop = FALSE]
  }

  return(new.scale[features, ])
}

#' Calculate pearson residuals of features not in the scale.data
#' This function is the secondary function under FetchResiduals
#'
#' @param object A seurat object
#' @param assay Name of the assay of the seurat object generated by
#' SCTransform. Default is "SCT"
#' @param umi.assay Name of the assay of the seurat object to fetch
#' UMIs from. Default is "RNA"
#' @param layer Name of the layer under `umi.assay` to fetch UMIs from.
#' Default is "counts"
#' @param chunk_size Number of cells to load in memory for calculating
#' residuals
#' @param layer.cells Vector of cells to calculate the residual for.
#' Default is NULL which uses all cells in the layer
#' @param SCTModel Which SCTmodel to use from the object for calculating
#' the residual. Will be ignored if reference.SCT.model is set
#' @param reference.SCT.model If a reference SCT model should be used
#' for calculating the residuals. When set to not NULL, ignores the `SCTModel`
#' parameter.
#' @param new_features A vector of features to calculate the residuals for
#' @param clip.range Numeric of length two specifying the min and max values
#' the Pearson residual will be clipped to. When NULL the range stored with
#' the model is used. With a reference model the range stored with the
#' reference always applies.
#' @param replace.value Whether to replace the value of residuals if it
#' already exists
#' @param verbose Whether to print messages and progress bars
#'
#' @return Returns a matrix containing centered pearson residuals of
#' added features. Features without a model are returned as NA rows, and
#' NULL is returned (with a warning) when `umi.assay` is not in the object.
#'
#' @importFrom sctransform get_residuals
#' @importFrom Matrix colSums
#'
#' @keywords internal
FetchResidualSCTModel <- function(
  object,
  assay = "SCT",
  umi.assay = "RNA",
  layer = "counts",
  chunk_size = 2000,
  layer.cells = NULL,
  SCTModel = NULL,
  reference.SCT.model = NULL,
  new_features = NULL,
  clip.range = NULL,
  replace.value = FALSE,
  verbose = FALSE
) {
  # keep the caller's request apart from the model defaults filled in below
  user.clip.range <- clip.range
  model.cells <- character()
  model.features <- Features(x = object, assay = assay)
  if (is.null(x = reference.SCT.model)){
    clip.range <- clip.range %||% SCTResults(object = object[[assay]], slot = "clips", model = SCTModel)$sct
    model.features <- rownames(x = SCTResults(object = object[[assay]], slot = "feature.attributes", model = SCTModel))
    model.cells <- Cells(x = slot(object = object[[assay]], name = "SCTModel.list")[[SCTModel]])
    sct.method <- SCTResults(object = object[[assay]], slot = "arguments", model = SCTModel)$sct.method %||% "default"
  } else {
    # use reference SCT model
    sct.method <- "reference"
  }

  layer.cells <- layer.cells %||% Cells(x = object[[umi.assay]], layer = layer)
  existing.scale.data <- NULL
  if (is.null(x = reference.SCT.model)){
    existing.scale.data <- suppressWarnings(GetAssayData(object = object, assay = assay, slot = "scale.data"))
  }
  # cells of this layer that already have scale.data
  scale.data.cells <- intersect(x = colnames(x = existing.scale.data), y = layer.cells)
  # stored features only count as existing when every layer cell is covered
  if (length(x = setdiff(x = layer.cells, y = scale.data.cells)) == 0) {
    existing_features <- rownames(x = existing.scale.data)
  } else {
    existing_features <- character()
  }
  if (replace.value) {
    features_to_compute <- new_features
  } else {
    features_to_compute <- setdiff(x = new_features, y = existing_features)
  }
  if (length(features_to_compute)<1){
    # Nothing to compute: hand back the stored residuals of the requested
    # features for this layer's cells. (The row lookup used rownames() of a
    # character vector, which is NULL, so a 0-row matrix was always returned.)
    return(existing.scale.data[
      intersect(x = new_features, y = rownames(x = existing.scale.data)),
      scale.data.cells,
      drop = FALSE
    ])
  }

  if (sct.method == "reference.model") {
    if (verbose) {
      message("sct.model ", SCTModel, " is from reference, so no residuals will be recalculated")
    }
    features_to_compute <- character()
  }
  if (!umi.assay %in% Assays(object = object)) {
    warning("The umi assay (", umi.assay, ") is not present in the object. ",
            "Cannot compute additional residuals.",
            call. = FALSE, immediate. = TRUE
    )
    return(NULL)
  }
  # these features do not have feature attributes
  diff_features <- setdiff(x = features_to_compute, y = model.features)
  intersect_features <- intersect(x = features_to_compute, y = model.features)
  if (sct.method == "reference") {
    vst_out <- SCTModel_to_vst(SCTModel = reference.SCT.model)

    # override clip.range: the reference's stored range always applies
    clip.range <- vst_out$arguments$sct.clip.range
    # get rid of the cell attributes
    vst_out$cell_attr <- NULL
    all.features <- intersect(
      x = rownames(x = vst_out$gene_attr),
      y = features_to_compute
    )
    vst_out$gene_attr <- vst_out$gene_attr[all.features, , drop = FALSE]
    vst_out$model_pars_fit <- vst_out$model_pars_fit[all.features, , drop = FALSE]
  } else {
    vst_out <- SCTModel_to_vst(SCTModel = slot(object = object[[assay]], name = "SCTModel.list")[[SCTModel]])
    # a range supplied by the caller wins; otherwise use the range stored in
    # the model arguments, falling back to the model's "clips" entry
    model.clip.range <- vst_out$arguments$sct.clip.range %||% clip.range
    clip.range <- user.clip.range %||% model.clip.range
  }
  clip.max <- max(clip.range)
  clip.min <- min(clip.range)

  if (length(x = diff_features) > 0) {
    #  Some features do not exist
    warning(
      "In the SCTModel ", SCTModel, ", the following ", length(x = diff_features),
      " features do not exist in the counts slot: ", paste(diff_features, collapse = ", ")
    )
    if (length(x = intersect_features) == 0) {
      # No features exist
      return(matrix(
        data = NA,
        nrow = length(x = features_to_compute),
        ncol = length(x = model.cells),
        dimnames = list(features_to_compute, model.cells)
      ))
    }
  }

  counts <- LayerData(
    object = object[[umi.assay]],
    layer = layer,
    cells = layer.cells
  )
  cells.vector <- seq_along(along.with = layer.cells)
  cells.grid <- split(x = cells.vector, f = ceiling(x = seq_along(along.with = cells.vector)/chunk_size))
  new_residuals <- vector(mode = "list", length = length(x = cells.grid))

  for (i in seq_len(length.out = length(x = cells.grid))) {
    vp <- cells.grid[[i]]
    block <- counts[,vp, drop=FALSE]
    umi.all <- as.sparse(x = block)

    # calculate min_variance for get_residuals
    # required when vst_out$arguments$min_variance == "umi_median"
    # only calculated once
    if (i==1){
      nz_median <- median(umi.all@x)
      min_var_custom <- (nz_median / 5)^2
    }
    # only features with a model can be computed; the rest are padded below
    umi <- umi.all[intersect_features, , drop = FALSE]

    ## Add cell_attr for missing cells
    block.umi <- colSums(umi.all)
    cell_attr <- data.frame(
      umi = block.umi,
      log_umi = log10(x = block.umi)
    )
    rownames(cell_attr) <- colnames(umi.all)
    if (sct.method %in% c("reference.model", "reference")) {
      vst_out$cell_attr <- cell_attr[colnames(umi.all), ,drop=FALSE]
    } else {
      cell_attr_existing <- vst_out$cell_attr
      cells_missing <- setdiff(rownames(cell_attr), rownames(cell_attr_existing))
      if (length(cells_missing)>0){
        cell_attr_missing <- cell_attr[cells_missing, ,drop=FALSE]
        missing_cols <- setdiff(x = colnames(x = cell_attr_existing),
                                y = colnames(x = cell_attr_missing))

        if (length(x = missing_cols) > 0) {
          cell_attr_missing[, missing_cols] <- NA
        }
        vst_out$cell_attr <- rbind(cell_attr_existing,
                                   cell_attr_missing)
        vst_out$cell_attr <- vst_out$cell_attr[colnames(umi), , drop=FALSE]
      }
    }
    if (verbose) {
      if (sct.method == "reference.model") {
        message("using reference sct model")
      } else {
        message("sct.model: ", SCTModel, " on ", ncol(x = umi), " cells: ",
                colnames(x = umi.all)[1], " .. ", colnames(x = umi.all)[ncol(umi.all)])
      }
    }

    if (vst_out$arguments$min_variance == "umi_median"){
      min_var <- min_var_custom
    } else {
      min_var <- vst_out$arguments$min_variance
    }
    if (nrow(umi)>0){
      vst_out.tmp <- vst_out
      # drop = FALSE keeps a data frame even with a single attribute column
      vst_out.tmp$cell_attr <- vst_out.tmp$cell_attr[colnames(x = umi), , drop = FALSE]
      new_residual <- get_residuals(
        vst_out = vst_out.tmp,
        umi = umi,
        residual_type = "pearson",
        min_variance = min_var,
        res_clip_range = c(clip.min, clip.max),
        verbosity = as.numeric(x = verbose) * 2
      )
    } else {
      return(matrix(
        data = NA,
        nrow = length(x = features_to_compute),
        ncol = length(x = colnames(umi.all)),
        dimnames = list(features_to_compute, colnames(umi.all))
      ))
    }
    new_residual <- as.matrix(x = new_residual)
    new_residuals[[i]] <- new_residual
  }
  new_residual <- do.call(what = cbind, args = new_residuals)
  # centered data if no reference model is provided
  if (is.null(x = reference.SCT.model)){
    new_residual <- new_residual - rowMeans(x = new_residual)
  } else {
    # subtract residual mean from reference model
    if (verbose){
      message("Using residual mean from reference for centering")
    }
    ref_vst <- SCTModel_to_vst(SCTModel = reference.SCT.model)
    ref.residuals.mean <- ref_vst$gene_attr[rownames(x = new_residual),"residual_mean"]
    new_residual <- sweep(
      x = new_residual,
      MARGIN = 1,
      STATS = ref.residuals.mean,
      FUN = "-"
    )
  }

  if (length(x = diff_features) > 0) {
    # Features without a model get NA rows so the result still covers every
    # feature that was asked for. (This case previously fell through with
    # `new_residual` undefined and failed with "object not found".)
    absent <- setdiff(x = features_to_compute, y = rownames(x = new_residual))
    if (length(x = absent) > 0) {
      na_rows <- matrix(
        data = NA,
        nrow = length(x = absent),
        ncol = ncol(x = new_residual),
        dimnames = list(absent, colnames(x = new_residual))
      )
      new_residual <- rbind(new_residual, na_rows)
    }
    new_residual <- new_residual[features_to_compute, , drop = FALSE]
  }

  old.features <- setdiff(x = new_features, y = features_to_compute)
  if (length(x = old.features) > 0) {
    # rbind() pairs columns by position, so take the stored residuals in the
    # column order of the freshly computed ones
    old_residuals <- GetAssayData(object = object[[assay]], slot = "scale.data")[old.features, colnames(x = new_residual), drop = FALSE]
    new_residual <- rbind(new_residual, old_residuals)[new_features, , drop = FALSE]
  }
  return(new_residual)
}

# Compute residuals for a count matrix with sctransform::get_residuals().
# A 'V3Matrix' is processed in a single call. An 'IterableMatrix' is processed
# in blocks of `chunk_size` columns, each converted with as.sparse() first,
# and the block results are bound together column-wise. Any other input class
# raises an error.
#' @importFrom sctransform get_residuals
GetResidualsChunked <- function(vst_out, layer.counts, residual_type, min_variance, res_clip_range, verbose, chunk_size=5000) {
  if (inherits(x = layer.counts, what = 'V3Matrix')) {
    return(get_residuals(
      vst_out = vst_out,
      umi = layer.counts,
      residual_type = residual_type,
      min_variance = min_variance,
      res_clip_range = res_clip_range,
      verbosity = as.numeric(x = verbose) * 2
    ))
  }
  if (!inherits(x = layer.counts, what = "IterableMatrix")) {
    stop("Data type not supported")
  }

  # assign consecutive columns to blocks of at most `chunk_size` cells
  col.index <- 1:ncol(x = layer.counts)
  block.id <- ceiling(x = seq_along(along.with = col.index) / chunk_size)
  block.residuals <- lapply(
    X = split(x = col.index, f = block.id),
    FUN = function(block.cols) {
      block.counts <- as.sparse(x = layer.counts[, block.cols])
      # restrict the cell attributes to the cells of this block
      block.model <- vst_out
      block.model$cell_attr <- block.model$cell_attr[colnames(x = block.counts), , drop = FALSE]
      get_residuals(
        vst_out = block.model,
        umi = block.counts,
        residual_type = residual_type,
        min_variance = min_variance,
        res_clip_range = res_clip_range,
        verbosity = as.numeric(x = verbose) * 2
      )
    }
  )
  Reduce(f = cbind, x = block.residuals)
}

#' Temporary function to get residuals from a reference SCT model
#'
#' Pearson residuals are computed with the reference model for the requested
#' features found in \code{object}, centered with the reference's
#' \code{residual_mean} and clipped to the clip range stored in the reference.
#'
#' @param object Count matrix with features as rows and cells as columns
#' @param reference.SCT.model a reference SCT model that should be used
#' for calculating the residuals
#' @param features Names of features to compute
#' @param nCount_UMI UMI counts. If not specified, defaults to
#' column sums of object
#' @param verbose Whether to print messages and progress bars
#' @importFrom sctransform get_residuals
#' @importFrom Matrix colSums
#'
#' @keywords internal
FetchResiduals_reference <- function(object,
                                     reference.SCT.model = NULL,
                                     features = NULL,
                                     nCount_UMI = NULL,
                                     verbose = FALSE) {
  # per-cell attributes for the cells in `object`
  nCount_UMI <- nCount_UMI %||% colSums(object)
  query.attr <- data.frame(
    umi = nCount_UMI,
    log_umi = log10(x = nCount_UMI)
  )
  rownames(query.attr) <- colnames(object)

  # only features that are present in the supplied counts
  target.features <- intersect(features, rownames(object))

  ref.vst <- SCTModel_to_vst(SCTModel = reference.SCT.model)
  # the clip range always comes from the reference model
  ref.clip <- ref.vst$arguments$sct.clip.range
  # the reference's own cells are not needed
  ref.vst$cell_attr <- NULL
  modelled <- intersect(
    x = rownames(x = ref.vst$gene_attr),
    y = target.features
  )
  ref.vst$gene_attr <- ref.vst$gene_attr[modelled, , drop = FALSE]
  ref.vst$model_pars_fit <- ref.vst$model_pars_fit[modelled, , drop = FALSE]

  upper <- max(ref.clip)
  lower <- min(ref.clip)

  query.counts <- object[target.features, , drop = FALSE]
  ref.vst$cell_attr <- query.attr

  if (verbose) {
    message("using reference sct model")
  }

  min_var <- if (vst_out_uses_median <- ref.vst$arguments$min_variance == "umi_median") {
    # NOTE(review): the non-zero median is fixed at 1 here rather than taken
    # from the counts - confirm this is intended
    nz_median <- 1
    (nz_median / 5)^2
  } else {
    ref.vst$arguments$min_variance
  }
  residual.mat <- get_residuals(
    vst_out = ref.vst,
    umi = query.counts,
    residual_type = "pearson",
    min_variance = min_var,
    verbosity = as.numeric(x = verbose) * 2
  )

  # center with the reference's residual means, then clip
  ref.means <- ref.vst$gene_attr[rownames(x = residual.mat), "residual_mean"]
  residual.mat <- sweep(
    x = residual.mat,
    MARGIN = 1,
    STATS = ref.means,
    FUN = "-"
  )
  MinMax(data = residual.mat, min = lower, max = upper)
}

#' Find variable features based on mean.var.plot
#'
#' Obtains per-feature statistics from \code{DISP}, flags the features whose
#' first and third columns (presumably the mean and the scaled dispersion
#' produced by \code{DISP}; verify against that function) lie strictly inside
#' the supplied cutoffs, and ranks the flagged features by decreasing
#' \code{mvp.dispersion}.
#'
#' @param data Data matrix
#' @param nselect Number of features to select based on dispersion values;
#' it is forwarded to \code{DISP} and not otherwise used here
#' @param verbose Whether to print messages and progress bars
#' @param mean.cutoff Numeric of length two specifying the min and max values
#' @param dispersion.cutoff Numeric of length two specifying the min and max values
#' @param ... Ignored
#'
#' @return The data frame returned by \code{DISP}, in its original row order,
#' with two added columns: \code{variable} (\code{TRUE} for features passing
#' both cutoffs) and \code{rank} (position of a selected feature when sorted by
#' decreasing \code{mvp.dispersion}; \code{NA} for unselected features)
#'
#' @keywords internal
#'
MVP <- function(
  data,
  verbose = TRUE,
  nselect = 2000L,
  mean.cutoff = c(0.1, 8),
  dispersion.cutoff = c(1, Inf),
  ...
) {
  hvf.info <- DISP(data = data, nselect = nselect, verbose = verbose)
  hvf.info$variable <- FALSE
  hvf.info$rank <- NA
  # Keep the sort permutation so the input order can be restored at the end
  # without relying on the row names being coercible to numbers
  by.dispersion <- order(hvf.info$mvp.dispersion, decreasing = TRUE)
  hvf.info <- hvf.info[by.dispersion, , drop = FALSE]
  means.use <- (hvf.info[, 1] > mean.cutoff[1]) & (hvf.info[, 1] < mean.cutoff[2])
  dispersions.use <- (hvf.info[, 3] > dispersion.cutoff[1]) & (hvf.info[, 3] < dispersion.cutoff[2])
  # Rows are sorted by decreasing dispersion at this point, so the running
  # index over the selected rows is their rank
  selected <- which(x = means.use & dispersions.use)
  hvf.info[selected, 'variable'] <- TRUE
  hvf.info$rank[selected] <- seq_along(along.with = selected)
  # order() of a permutation is its inverse: this puts rows back as DISP gave them
  hvf.info <- hvf.info[order(by.dispersion), , drop = FALSE]
  return(hvf.info)
}
